/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to you under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.calcite.test;

import org.apache.calcite.DataContexts;
import org.apache.calcite.adapter.enumerable.EnumerableConvention;
import org.apache.calcite.adapter.enumerable.EnumerableInterpretable;
import org.apache.calcite.adapter.enumerable.EnumerableRel;
import org.apache.calcite.adapter.enumerable.EnumerableRules;
import org.apache.calcite.jdbc.CalcitePrepare;
import org.apache.calcite.jdbc.JavaTypeFactoryImpl;
import org.apache.calcite.linq4j.Enumerable;
import org.apache.calcite.linq4j.Enumerator;
import org.apache.calcite.plan.ConventionTraitDef;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRules;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.plan.hep.HepMatchOrder;
import org.apache.calcite.plan.hep.HepPlanner;
import org.apache.calcite.plan.hep.HepProgram;
import org.apache.calcite.plan.hep.HepProgramBuilder;
import org.apache.calcite.plan.volcano.VolcanoPlanner;
import org.apache.calcite.rel.RelCollationTraitDef;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.core.RelFactories;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rel.rules.CoreRules;
import org.apache.calcite.rel.rules.DpHyp;
import org.apache.calcite.rel.rules.HyperGraph;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.runtime.Bindable;
import org.apache.calcite.tools.RelBuilder;

import org.junit.jupiter.api.Test;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.function.Function;

/**
 * Tests for execution results of all candidate plans generated by dphyp
 * (dynamic programming join reordering over a hypergraph). Each test builds a
 * logical join tree, converts it to a {@link HyperGraph}, enumerates all
 * candidate join orders, executes each candidate plan, and checks that every
 * candidate produces the same (order-insensitive) result set.
 */
public class DphypJoinReorderTest {

  @Test void testInnerLeftSemiJoinReorder() {
    // select * from t1 inner join t2 on id1=id2
    // left join t3 on id2=id3
    // where exists (select * from t4 where id3=id4)
    Function<RelBuilder, RelNode> function = builder ->
        builder
            .values(new String[]{"id1", "name1"}, 1, "Tom", 2, "Lucy", 3, "Li")
            .values(new String[]{"id2", "name2"}, 1, "Tom", 2, "Lucy")
            .join(
                JoinRelType.INNER,
                builder.equals(
                    builder.field(2, 0, "id1"),
                    builder.field(2, 1, "id2")))
            .values(new String[]{"id3", "name3"}, 1, "Tom", 3, "Li", 4, "Peter")
            .join(
                JoinRelType.LEFT,
                builder.equals(
                    builder.field(2, 0, "id2"),
                    builder.field(2, 1, "id3")))
            .values(new String[]{"id4", "name4"}, 1, "Tom", 3, "Li")
            .semiJoin(
                builder.equals(
                    builder.field(2, 0, "id3"),
                    builder.field(2, 1, "id4")))
            .build();

    String expectedResult = "id1=1; name1=Tom; id2=1; name2=Tom; id3=1; name3=Tom\n";
    run(function, expectedResult);
  }

  @Test void testChainInnerJoinReorder() {
    // select * from t1 inner join t2 on id1=id2
    // inner join t3 on id2=id3
    // inner join t4 on id3=id4
    Function<RelBuilder, RelNode> function = builder ->
        builder
            .values(new String[]{"id1", "name1"}, 1, "Tom", 2, null, 3, "Li")
            .values(new String[]{"id2", "name2"}, 2, "Lucy")
            .join(
                JoinRelType.INNER,
                builder.equals(
                    builder.field(2, 0, "id1"),
                    builder.field(2, 1, "id2")))
            .values(new String[]{"id3", "name3"}, 2, "Lucy", 3, "Li", 4, "Peter")
            .join(
                JoinRelType.INNER,
                builder.equals(
                    builder.field(2, 0, "id2"),
                    builder.field(2, 1, "id3")))
            .values(new String[]{"id4", "name4"}, 2, "Lucy", 5, "Andy")
            .join(
                JoinRelType.INNER,
                builder.equals(
                    builder.field(2, 0, "id3"),
                    builder.field(2, 1, "id4")))
            .build();

    String expectedResult =
        "id1=2; name1=null; id2=2; name2=Lucy; id3=2; name3=Lucy; id4=2; name4=Lucy\n";
    run(function, expectedResult);
  }

  @Test void testStarInnerJoinReorder() {
    // select * from t1 inner join t2 on id1=id2
    // inner join t3 on id1=id3
    // inner join t4 on id1=id4
    Function<RelBuilder, RelNode> function = builder ->
        builder
            .values(new String[]{"id1", "name1"}, 1, "Tom", 2, "Lucy", 3, "Li")
            .values(new String[]{"id2", "name2"}, 1, "Tom", 2, "Lucy")
            .join(
                JoinRelType.INNER,
                builder.equals(
                    builder.field(2, 0, "id1"),
                    builder.field(2, 1, "id2")))
            .values(new String[]{"id3", "name3"}, 1, "Tom", 2, "Lucy", 3, "Li", 4, "Peter")
            .join(
                JoinRelType.INNER,
                builder.equals(
                    builder.field(2, 0, "id1"),
                    builder.field(2, 1, "id3")))
            .values(new String[]{"id4", "name4"}, 2, "Lucy", 5, "Andy", 1, "Tom")
            .join(
                JoinRelType.INNER,
                builder.equals(
                    builder.field(2, 0, "id1"),
                    builder.field(2, 1, "id4")))
            .build();

    String expectedResult =
        "id1=1; name1=Tom; id2=1; name2=Tom; id3=1; name3=Tom; id4=1; name4=Tom\n"
            + "id1=2; name1=Lucy; id2=2; name2=Lucy; id3=2; name3=Lucy; id4=2; name4=Lucy\n";
    run(function, expectedResult);
  }

  /**
   * Run the customized plan and check whether its execution results meet expectations
   * (ignore row order).
   *
   * @param customPlanFunction  a function that accepts one RelBuilder and produces a RelNode
   * @param expectedResult      expected result, one row per line in
   *                            {@code field=value; ...} format
   */
  void run(Function<RelBuilder, RelNode> customPlanFunction, String expectedResult) {
    VolcanoPlanner volcanoPlanner = new VolcanoPlanner();
    volcanoPlanner.addRelTraitDef(ConventionTraitDef.INSTANCE);
    volcanoPlanner.addRelTraitDef(RelCollationTraitDef.INSTANCE);

    // initialize plan
    RelDataTypeFactory typeFactory =
        new JavaTypeFactoryImpl();
    RelOptCluster cluster = RelOptCluster.create(volcanoPlanner, new RexBuilder(typeFactory));
    RelBuilder builder = RelFactories.LOGICAL_BUILDER.create(cluster, null);
    RelNode initRel = customPlanFunction.apply(builder);

    // convert the whole join tree to a single HyperGraph node (bottom-up so
    // that nested joins are folded in before their parents)
    HepProgram program = new HepProgramBuilder()
        .addMatchOrder(HepMatchOrder.BOTTOM_UP)
        .addRuleInstance(CoreRules.JOIN_TO_HYPER_GRAPH)
        .build();
    HepPlanner hepPlanner = new HepPlanner(program);
    hepPlanner.setRoot(initRel);
    RelNode hyperGraph = hepPlanner.findBestExp();
    assert hyperGraph instanceof HyperGraph : "The root node must be a HyperGraph, "
        + "please check the custom plan function. The root node now is:\n"
        + RelOptUtil.toString(hyperGraph);

    // start dphyp enumeration; DphypForTest records every complete candidate
    // plan instead of keeping only the cheapest one
    DphypForTest dphyp =
        new DphypForTest(
            (HyperGraph) hyperGraph,
            builder,
            hyperGraph.getCluster().getMetadataQuery(),
            127);
    dphyp.startEnumerateJoin();

    // verify the execution results of each candidate plan enumerated by dphyp
    for (RelNode candidatePlan : dphyp.candidateList) {
      // volcanoPlanner.clear() at the end of each iteration wipes the rule
      // set, so the rules must be re-registered for every candidate plan
      for (RelOptRule enumerableRule : EnumerableRules.rules()) {
        volcanoPlanner.addRule(enumerableRule);
      }
      // EnumerableProject does not implement the 'implement' function.
      // must replace Project with Calc to execute the plan.
      for (RelOptRule calcRule : RelOptRules.CALC_RULES) {
        volcanoPlanner.addRule(calcRule);
      }
      volcanoPlanner.removeRule(EnumerableRules.ENUMERABLE_PROJECT_RULE);

      candidatePlan =
          volcanoPlanner.changeTraits(
              candidatePlan,
              cluster.traitSet().replace(EnumerableConvention.INSTANCE));
      volcanoPlanner.setRoot(candidatePlan);
      RelNode physicalPlan = volcanoPlanner.findBestExp();
      volcanoPlanner.clear();

      // compile and execute the physical plan, then compare row sets
      Bindable bindable =
          EnumerableInterpretable.toBindable(
              Collections.emptyMap(),
              CalcitePrepare.Dummy.getSparkHandler(false),
              (EnumerableRel) physicalPlan,
              EnumerableRel.Prefer.ARRAY);
      Enumerable<Object[]> result = bindable.bind(DataContexts.EMPTY);
      Set<String> realResult =
          formatResult(result.enumerator(), physicalPlan.getRowType().getFieldNames());

      Set<String> expectedSet = new HashSet<>();
      for (String row : expectedResult.split("\n")) {
        expectedSet.add(row.trim());
      }
      assert realResult.equals(expectedSet) : "The result does not match the expected result.\n"
          + "Expected: " + expectedSet
          + "\nActual: " + realResult;
    }
  }

  /**
   * Drains the enumerator and formats each row as
   * {@code field1=value1; field2=value2; ...}, returning the rows as a set so
   * that comparisons ignore row order. The enumerator is closed when done.
   */
  Set<String> formatResult(Enumerator<Object[]> enumerator, List<String> fieldNames) {
    Set<String> rowSet = new HashSet<>();
    try {
      while (enumerator.moveNext()) {
        StringBuilder output = new StringBuilder();
        Object[] row = enumerator.current();
        assert row.length == fieldNames.size()
            : "Result row length does not match field names size";
        for (int i = 0; i < fieldNames.size(); i++) {
          output.append(fieldNames.get(i))
              .append("=")
              .append(row[i] != null ? row[i].toString() : "null");

          if (i < fieldNames.size() - 1) {
            output.append("; ");
          }
        }
        rowSet.add(output.toString());
      }
    } finally {
      // Enumerator is a resource; release it even if formatting fails
      enumerator.close();
    }
    return rowSet;
  }

  /** A test subclass of {@link DpHyp} that collects candidate plans. */
  static class DphypForTest extends DpHyp {

    // a list for candidate plans that include all tables in the hypergraph
    final List<RelNode> candidateList;

    DphypForTest(
        HyperGraph hyperGraph,
        RelBuilder builder,
        RelMetadataQuery relMetadataQuery,
        int bloat) {
      super(hyperGraph, builder, relMetadataQuery, bloat);
      candidateList = new ArrayList<>();
    }

    /**
     * Intercepts the row-type verification hook: when the plan covers every
     * input of the hypergraph (i.e. it is a complete join order), restores the
     * original projection order on top of it and records the result as a
     * candidate plan.
     */
    @Override protected boolean verifyDpResultRowType(
        RelNode plan,
        List<HyperGraph.NodeState> resultOrder) {
      if (hyperGraph.getInputs().size() == resultOrder.size()) {
        List<RexNode> projects =
            hyperGraph.restoreProjectionOrder(resultOrder,
                plan.getRowType().getFieldList());
        RelNode resultNode = builder
            .push(plan)
            .project(projects)
            .build();
        candidateList.add(resultNode);
        return RelOptUtil.areRowTypesEqual(resultNode.getRowType(), hyperGraph.getRowType(), false);
      }
      // partial plans are not checked here
      return true;
    }
  }

}
